{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Iterators\n",
    "(c) Deniz Yuret, 2019\n",
    "\n",
    "* Objective: Learn how to construct and use Julia iterators.\n",
    "* Reading: [Interfaces](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration-1),  [Collections](https://docs.julialang.org/en/v1/base/collections/#lib-collections-iteration-1), [Iteration Utilities](https://docs.julialang.org/en/v1/base/iterators) and [Generator expressions](https://docs.julialang.org/en/v1/manual/arrays/#Generator-Expressions-1) in the Julia manual.\n",
    "* Prerequisites: [minibatch, Data](https://github.com/denizyuret/Knet.jl/blob/master/src/data.jl) from the [MNIST notebook](20.mnist.ipynb)\n",
    "* New functions: \n",
    "[first](https://docs.julialang.org/en/v1/base/collections/#Base.first), \n",
    "[collect](https://docs.julialang.org/en/v1/base/collections/#Base.collect-Tuple{Any}), \n",
    "[take](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.take), \n",
    "[drop](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.drop), \n",
    "[cycle](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.cycle),\n",
    "[ncycle](https://juliacollections.github.io/IterTools.jl/stable/#ncycle(xs,-n)-1),\n",
    "[takenth](https://juliacollections.github.io/IterTools.jl/stable/#takenth(xs,-n)-1),\n",
    "[takewhile](https://juliacollections.github.io/IterTools.jl/stable/#takewhile(cond,-xs)-1),\n",
    "[Stateful](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.Stateful), \n",
    "[iterate](https://docs.julialang.org/en/v1/base/collections/#lib-collections-iteration-1)\n",
    "\n",
    "The `minibatch` function returns a `Knet.Data` object, a Julia iterator that generates (x,y) minibatches. Iterators are lazy: they produce each element only when it is requested, so no time or memory is spent creating and storing all the elements at once. Iterators can even be infinite! The training algorithms in Knet are also implemented as iterators so that:\n",
    "1. We can monitor and report the training loss\n",
    "2. We can take snapshots of the model during training\n",
    "3. We can pause/terminate training when necessary\n",
    "\n",
    "Here are some things Julia can do with iterators:"
   ]
  },
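  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a first illustration of laziness, here is a minimal sketch that does not need Knet or MNIST. It assumes only `countfrom` and `take` from `Base.Iterators`; the name `squares` is made up for this example. The generator describes an infinite sequence, yet nothing is computed until `collect` asks `take` for five elements.\n",
    "\n",
    "```julia\n",
    "using Base.Iterators: take, countfrom\n",
    "\n",
    "# Infinite and lazy: no squares are computed when the generator is created\n",
    "squares = (i^2 for i in countfrom(1))\n",
    "\n",
    "# Only the first five elements are ever produced\n",
    "collect(take(squares, 5))   # [1, 4, 9, 16, 25]\n",
    "```\n",
    "\n",
    "Calling `collect(squares)` directly would never finish, which is why infinite iterators are normally wrapped in `take`, `takewhile`, or a loop with `break`."
   ]
  },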
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Set display width, load packages, import symbols\n",
    "ENV[\"COLUMNS\"]=72\n",
    "using Base.Iterators: take, drop, cycle, Stateful\n",
    "using IterTools: ncycle, takenth, takewhile\n",
    "using MLDatasets: MNIST\n",
    "using Knet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100-element Knet.Train20.Data{Tuple{Array{Float32,3},Array{Int64,1}}}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load MNIST data as an iterator of (x,y) minibatches\n",
    "xtst,ytst = MNIST.testdata(Float32)\n",
    "dtst = minibatch(xtst, ytst, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(\"28×28×100 Array{Float32,3}\", \"100-element Array{Int64,1}\")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# We can peek at the first element using first()\n",
    "summary.(first(dtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 100\n"
     ]
    }
   ],
   "source": [
    "# Iterators can be used in for loops\n",
    "# Let's count the elements in dtst:\n",
    "n = 0\n",
    "for (x,y) in dtst; global n += 1; end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"100-element Array{Tuple{Array{Float32,3},Array{Int64,1}},1}\""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Iterators can be converted to arrays using `collect`\n",
    "# (avoid this unless necessary: it allocates every element at once. Prefer a for loop.)\n",
    "collect(dtst) |> summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 500\n"
     ]
    }
   ],
   "source": [
    "# We can generate an iterator for multiple epochs using `ncycle`\n",
    "# (an epoch is a single pass over the dataset)\n",
    "n = 0\n",
    "for (x,y) in ncycle(dtst,5); global n += 1; end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 20\n"
     ]
    }
   ],
   "source": [
    "# We can generate partial epochs using `take` which takes the first n elements\n",
    "n = 0\n",
    "for (x,y) in take(dtst,20); global n += 1; end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 80\n"
     ]
    }
   ],
   "source": [
    "# We can also generate partial epochs using `drop` which drops the first n elements\n",
    "n = 0\n",
    "for (x,y) in drop(dtst,20); global n += 1; end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 1235\n"
     ]
    }
   ],
   "source": [
    "# We can repeat forever using `cycle`\n",
    "# Never collect a cycle or loop over it without a break: it would not terminate!\n",
    "n = 0\n",
    "for (x,y) in cycle(dtst); (global n += 1) > 1234 && break; end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 56\n"
     ]
    }
   ],
   "source": [
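    "# We can keep taking elements while a condition holds using `takewhile`\n",
    "# (the predicate receives each element; here it ignores it and checks the counter n)\n",
    "# This is useful for training until convergence\n",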
    "n = 0\n",
    "for (x,y) in takewhile(x->(n<56), dtst); global n += 1; end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 16\n"
     ]
    }
   ],
   "source": [
    "# We can take every nth element using `takenth`\n",
    "# This is useful to report progress every nth iteration\n",
    "n = 0\n",
    "for (x,y) in takenth(dtst,6); global n += 1; end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×100 LinearAlgebra.Adjoint{Float32,Array{Float32,1}}:\n",
       " 7990.35  7842.33  8162.68  7692.77  …  8494.0  7361.33  8643.01"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# We can construct new iterators using [Generator expressions](https://docs.julialang.org/en/v1/manual/arrays/#Generator-Expressions-1)\n",
    "# The following example constructs an iterator over the squared norms of x in a dataset:\n",
    "xnorm(data) = (sum(abs2,x) for (x,y) in data)\n",
    "collect(xnorm(dtst))'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n = 100\n"
     ]
    }
   ],
   "source": [
    "# Every iterator implements the `iterate` function, which returns\n",
    "# an (element, state) tuple, or `nothing` when no elements are left.\n",
    "# Here is how the for loop for dtst is implemented:\n",
    "n = 0; next = iterate(dtst)\n",
    "while next !== nothing\n",
    "    ((_x,_y), state) = next\n",
    "    global n += 1\n",
    "    global next = iterate(dtst,state)\n",
    "end\n",
    "@show n;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×100 LinearAlgebra.Adjoint{Any,Array{Any,1}}:\n",
       " 7990.35  7842.33  8162.68  7692.77  …  8494.0  7361.33  8643.01"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# You can define your own iterator by declaring a new type and adding a method to `Base.iterate`.\n",
    "# Here is another way to define an iterator over the squared norms of x in a dataset:\n",
    "struct Xnorm; itr; end\n",
    "\n",
    "function Base.iterate(f::Xnorm, s...)\n",
    "    next = iterate(f.itr, s...)\n",
    "    next === nothing && return nothing\n",
    "    ((x,y),state) = next\n",
    "    return sum(abs2,x), state\n",
    "end\n",
    "\n",
    "Base.length(f::Xnorm) = length(f.itr) # collect needs this\n",
    "\n",
    "collect(Xnorm(dtst))'"
   ]
  },
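  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same interface can be sketched without any dataset. The `Countdown` type below is hypothetical and exists only for illustration. It shows the three methods a small iterator typically defines: `Base.iterate`, `Base.length`, and `Base.eltype`.\n",
    "\n",
    "```julia\n",
    "# Hypothetical example type: counts down from `from` to 1\n",
    "struct Countdown\n",
    "    from::Int\n",
    "end\n",
    "\n",
    "# The state is the next value to emit; return nothing when exhausted\n",
    "Base.iterate(c::Countdown, state=c.from) = state < 1 ? nothing : (state, state - 1)\n",
    "\n",
    "# collect uses length to size the result and eltype to type it\n",
    "Base.length(c::Countdown) = max(c.from, 0)\n",
    "Base.eltype(::Type{Countdown}) = Int\n",
    "\n",
    "collect(Countdown(3))   # [3, 2, 1]\n",
    "```\n",
    "\n",
    "Without the `Base.eltype` method the element type defaults to `Any`, which is why collecting `Xnorm` above yields an array of `Any`."
   ]
  },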
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7, 2, 1, 0, 4]\n",
      "[7, 2, 1, 0, 4]\n",
      "[7, 2, 1, 0, 4]\n",
      "[6, 0, 5, 4, 9]\n"
     ]
    }
   ],
   "source": [
    "# We can make an iterator `Stateful` so it remembers where it left off.\n",
    "# (by default iterators start from the beginning)\n",
    "dtst1 = dtst            # dtst1 will start from beginning every time\n",
    "dtst2 = Stateful(dtst)  # dtst2 will remember where we left off\n",
    "for (x,y) in dtst1; println(Int.(y[1:5])); break; end\n",
    "for (x,y) in dtst1; println(Int.(y[1:5])); break; end\n",
    "for (x,y) in dtst2; println(Int.(y[1:5])); break; end\n",
    "for (x,y) in dtst2; println(Int.(y[1:5])); break; end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7, 2, 1, 0, 4]\n",
      "[7, 2, 1, 0, 4]\n",
      "[3, 3, 8, 0, 4]\n",
      "[0, 1, 4, 2, 3]\n"
     ]
    }
   ],
   "source": [
    "# We can shuffle instances at every epoch using the keyword argument `shuffle=true`\n",
    "# (by default elements are generated in the same order)\n",
    "dtst1 = minibatch(xtst,ytst,100)              # dtst1 iterates in the same order\n",
    "dtst2 = minibatch(xtst,ytst,100,shuffle=true) # dtst2 shuffles each time\n",
    "for (x,y) in dtst1; println(Int.(y[1:5])); break; end\n",
    "for (x,y) in dtst1; println(Int.(y[1:5])); break; end\n",
    "for (x,y) in dtst2; println(Int.(y[1:5])); break; end\n",
    "for (x,y) in dtst2; println(Int.(y[1:5])); break; end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Julia 1.5.0",
   "language": "julia",
   "name": "julia-1.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
